BATS forecaster (multiple seasonality)#

BATS stands for Box–Cox transform, ARMA errors, Trend, and Seasonal components.

This notebook builds a practical BATS-style forecaster that supports multiple seasonalities (e.g., weekly + monthly), with a scikit-learn-like API:

  • BATS(use_box_cox=..., box_cox_bounds=..., use_trend=..., use_damped_trend=..., seasonal_periods=..., use_arma_errors=...)

  • model = bats.fit(y)

  • forecast = model.forecast(steps)

Implementation note: the original BATS model is formulated in state-space / exponential-smoothing form. Here we implement a BATS-style forecaster using:

  • explicit trend + seasonal design matrices, and

  • ARMA errors estimated via statsmodels (SARIMAX with d=0).

Model sketch (math)#

Let \(y_t\) be the observed series.

Box–Cox transform (optional)#

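For \(y_t>0\) and parameter \(\lambda\):

\[g_\lambda(y_t) = \begin{cases} \dfrac{y_t^{\lambda}-1}{\lambda}, & \lambda \ne 0 \\ \log(y_t), & \lambda = 0 \end{cases}\]

If the data are not strictly positive, the implementation first adds a constant shift so that the smallest observation becomes 1, and subtracts the same shift again after back-transforming.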

We model the transformed series \(x_t = g_\lambda(y_t)\).

Trend + multiple seasonalities#

\[x_t = \underbrace{\beta_0}_{\text{level}} + \underbrace{\beta_1\,f(t)}_{\text{trend (optional)}} + \sum_{k=1}^{K} S^{(k)}_t + u_t\]
  • \(f(t)=t\) for a linear trend.

  • For a simple damped trend option we use \(f(t)=\dfrac{1-\phi^t}{1-\phi}\) with damping \(\phi\in(0,1)\).

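  • Each seasonal component \(S^{(k)}_t\) is encoded with seasonal dummies for period \(m_k\): one indicator per position \(t \bmod m_k\), with the first position dropped because it is absorbed by the level \(\beta_0\), which leaves \(m_k-1\) columns per seasonality.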

ARMA errors (optional)#

\[u_t = \sum_{i=1}^p \varphi_i u_{t-i} + \varepsilon_t + \sum_{j=1}^q \vartheta_j\varepsilon_{t-j}, \qquad \varepsilon_t\sim\text{WN}(0,\sigma^2).\]
import warnings

import numpy as np
import pandas as pd

import plotly.graph_objects as go
from plotly.subplots import make_subplots
import os
import plotly.io as pio

from scipy import stats
import statsmodels.api as sm

# --- Notebook-wide configuration ---------------------------------------------
# Hide UserWarning messages for the rest of the session.
warnings.filterwarnings("ignore", category=UserWarning)

# Plotly defaults: white template; the renderer can be overridden through the
# PLOTLY_RENDERER environment variable and falls back to "notebook".
pio.templates.default = "plotly_white"
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")

# Seeded generator so the synthetic data below is reproducible.
rng = np.random.default_rng(7)

# Report the versions of the main dependencies.
import numpy
import pandas
import scipy
import statsmodels
import plotly

for _label, _module in (
    ("numpy:", numpy),
    ("pandas:", pandas),
    ("scipy:", scipy),
    ("statsmodels:", statsmodels),
    ("plotly:", plotly),
):
    print(_label, _module.__version__)
numpy: 1.26.2
pandas: 2.1.3
scipy: 1.15.0
statsmodels: 0.14.4
plotly: 6.5.2
class BoxCoxTransformer:
    """Optional (shifted) Box-Cox transform with a bounded estimate of lambda.

    When ``use_box_cox`` is False every method is a pass-through that returns
    a float copy of its input. Otherwise ``fit`` chooses a shift that makes the
    data strictly positive and estimates lambda by maximum likelihood
    *restricted to* ``box_cox_bounds``.

    Attributes set by ``fit``:
        shift_: constant added to the data before transforming (0.0 when the
            data are already strictly positive, otherwise ``1 - min(y)``).
        lambda_: fitted Box-Cox parameter, or None when the transform is
            disabled or not fitted yet.
    """

    def __init__(self, use_box_cox: bool, box_cox_bounds: tuple[float, float] = (0.0, 1.0)):
        self.use_box_cox = bool(use_box_cox)
        self.box_cox_bounds = tuple(float(v) for v in box_cox_bounds)
        self.shift_: float = 0.0
        self.lambda_: float | None = None

    def fit(self, y: np.ndarray) -> "BoxCoxTransformer":
        """Estimate the shift and lambda from ``y`` and return ``self``.

        Raises:
            ValueError: if the data contain NaN/inf, are not strictly positive
                after shifting, or the lower bound exceeds the upper bound.
        """
        y = np.asarray(y, dtype=float)
        if not self.use_box_cox:
            self.shift_ = 0.0
            self.lambda_ = None
            return self

        if not np.all(np.isfinite(y)):
            # A NaN would otherwise propagate into the shift and lambda and
            # silently turn every later transform into NaN.
            raise ValueError("Box-Cox requires finite data (no NaN or inf).")

        min_y = float(np.min(y))
        self.shift_ = 0.0 if min_y > 0.0 else (1.0 - min_y)
        y_pos = y + self.shift_
        if np.any(y_pos <= 0.0):
            raise ValueError("Box-Cox requires strictly positive data (even after shift).")

        lo, hi = self.box_cox_bounds
        if lo > hi:
            raise ValueError("box_cox_bounds must satisfy lower <= upper.")
        if lo == hi:
            # Degenerate interval: the caller pinned lambda to a single value.
            self.lambda_ = lo
            return self

        # `brack=` in boxcox_normmax is only a starting bracket for Brent's
        # method and the optimum may fall outside it, so the bounds are
        # enforced with a bounded scalar minimiser instead.
        from scipy import optimize

        def _bounded_optimizer(fun):
            return optimize.minimize_scalar(fun, bounds=(lo, hi), method="bounded")

        self.lambda_ = float(stats.boxcox_normmax(y_pos, method="mle", optimizer=_bounded_optimizer))
        return self

    def transform(self, y: np.ndarray) -> np.ndarray:
        """Apply the fitted transform to ``y``.

        Raises:
            RuntimeError: if the transform is enabled but ``fit`` was not called.
            ValueError: if any shifted value is not strictly positive.
        """
        y = np.asarray(y, dtype=float)
        if not self.use_box_cox:
            return y.copy()
        if self.lambda_ is None:
            raise RuntimeError("Call fit() before transform().")

        y_pos = y + self.shift_
        if np.any(y_pos <= 0.0):
            raise ValueError("Box-Cox requires strictly positive data (even after shift).")

        lmbda = float(self.lambda_)
        if abs(lmbda) < 1e-10:
            # lambda ~ 0 is treated as the logarithmic limit of the transform.
            return np.log(y_pos)
        return (np.power(y_pos, lmbda) - 1.0) / lmbda

    def inverse_transform(self, x: np.ndarray) -> np.ndarray:
        """Map transformed values back to the original scale (shift removed).

        Values with ``lambda * x + 1 < 0`` lie outside the range of the
        transform; ``np.power`` yields NaN for them when ``1 / lambda`` is not
        an integer.

        Raises:
            RuntimeError: if the transform is enabled but ``fit`` was not called.
        """
        x = np.asarray(x, dtype=float)
        if not self.use_box_cox:
            return x.copy()
        if self.lambda_ is None:
            raise RuntimeError("Call fit() before inverse_transform().")

        lmbda = float(self.lambda_)
        if abs(lmbda) < 1e-10:
            y_pos = np.exp(x)
        else:
            y_pos = np.power(lmbda * x + 1.0, 1.0 / lmbda)
        return y_pos - self.shift_


def _acf(x: np.ndarray, max_lag: int) -> tuple[np.ndarray, np.ndarray]:
    x = np.asarray(x, dtype=float)
    x = x - x.mean()
    denom = float(np.dot(x, x))
    lags = np.arange(max_lag + 1)
    values = np.zeros(max_lag + 1)
    values[0] = 1.0
    if denom == 0.0:
        return lags, values
    for k in range(1, max_lag + 1):
        values[k] = float(np.dot(x[k:], x[:-k]) / denom)
    return lags, values


def seasonal_dummies(t: np.ndarray, period: int, *, drop_first: bool = True) -> np.ndarray:
    """Build 0/1 indicator columns for the seasonal position ``t mod period``.

    Written for a one-dimensional ``t``. With ``drop_first=True`` position 0
    has no column of its own (rows at position 0 are all zeros), giving
    ``period - 1`` columns; otherwise ``period`` columns are returned. A period
    of 1 or less yields a matrix with zero columns.
    """
    t = np.asarray(t, dtype=int)
    period = int(period)
    if period <= 1:
        return np.zeros((t.size, 0), dtype=float)

    phase = t % period
    # Column j stands for seasonal position kept_phases[j].
    kept_phases = np.arange(1 if drop_first else 0, period)
    return (phase[:, None] == kept_phases[None, :]).astype(float)


def trend_feature(t: np.ndarray, *, use_damped: bool, damped_phi: float) -> np.ndarray:
    """Return the trend regressor evaluated at the time indices ``t``.

    Without damping this is ``t`` itself (as floats). With damping it is
    ``(1 - phi**t) / (1 - phi)``, which is 0 at ``t = 0``.

    Raises:
        ValueError: if damping is requested and ``damped_phi`` is not strictly
            between 0 and 1. ``damped_phi`` is not checked when damping is off.
    """
    steps = np.asarray(t, dtype=float)
    if use_damped:
        phi = float(damped_phi)
        if 0.0 < phi < 1.0:
            return (1.0 - np.power(phi, steps)) / (1.0 - phi)
        raise ValueError("damped_phi must be in (0, 1)")
    return steps


def bats_design_matrix(
    t: np.ndarray,
    *,
    use_trend: bool,
    use_damped_trend: bool,
    damped_trend_phi: float,
    seasonal_periods: list[int] | None,
) -> np.ndarray:
    """Assemble the regression matrix for the time indices ``t``.

    Columns, left to right: a constant column of ones, the trend regressor
    (only if ``use_trend``), then one group of seasonal dummies (first level
    dropped) per entry of ``seasonal_periods``, in the order given.
    """
    t = np.asarray(t, dtype=int)
    blocks = [np.ones((t.size, 1), dtype=float)]

    if use_trend:
        trend_values = trend_feature(
            t.astype(float),
            use_damped=use_damped_trend,
            damped_phi=damped_trend_phi,
        )
        blocks.append(trend_values.reshape(-1, 1))

    # None or an empty list contributes no seasonal columns.
    blocks.extend(
        seasonal_dummies(t, period=int(m), drop_first=True)
        for m in (seasonal_periods or [])
    )
    return np.concatenate(blocks, axis=1)
class BATSModel:
    def __init__(
        self,
        *,
        results,
        transformer: BoxCoxTransformer,
        use_trend: bool,
        use_damped_trend: bool,
        damped_trend_phi: float,
        seasonal_periods: list[int] | None,
        y_index,
    ):
        self.results = results
        self.transformer = transformer
        self.use_trend = use_trend
        self.use_damped_trend = use_damped_trend
        self.damped_trend_phi = float(damped_trend_phi)
        self.seasonal_periods = seasonal_periods
        self.y_index = y_index

    @property
    def n_obs(self) -> int:
        return int(self.results.nobs)

    def fitted_values(self) -> np.ndarray:
        fitted_x = np.asarray(self.results.fittedvalues, dtype=float)
        return self.transformer.inverse_transform(fitted_x)

    def residuals(self) -> np.ndarray:
        # Residuals in the transformed space (more natural under Box-Cox)
        return np.asarray(self.results.resid, dtype=float)

    def forecast(self, steps: int, *, alpha: float = 0.05) -> dict[str, np.ndarray]:
        steps = int(steps)
        t_future = np.arange(self.n_obs, self.n_obs + steps)
        X_future = bats_design_matrix(
            t_future,
            use_trend=self.use_trend,
            use_damped_trend=self.use_damped_trend,
            damped_trend_phi=self.damped_trend_phi,
            seasonal_periods=self.seasonal_periods,
        )

        fcst = self.results.get_forecast(steps=steps, exog=X_future)
        mean_x = np.asarray(fcst.predicted_mean, dtype=float)

        ci = fcst.conf_int(alpha=alpha)
        ci_np = np.asarray(ci)
        lower_x = ci_np[:, 0]
        upper_x = ci_np[:, 1]

        mean_y = self.transformer.inverse_transform(mean_x)
        lower_y = self.transformer.inverse_transform(lower_x)
        upper_y = self.transformer.inverse_transform(upper_x)

        return {"mean": mean_y, "lower": lower_y, "upper": upper_y}


class BATS:
    def __init__(
        self,
        *,
        use_box_cox: bool = False,
        box_cox_bounds: tuple[float, float] = (0.0, 1.0),
        use_trend: bool = True,
        use_damped_trend: bool = False,
        damped_trend_phi: float = 0.98,
        seasonal_periods: list[int] | None = None,
        use_arma_errors: bool = True,
        arma_order: tuple[int, int] | None = (1, 1),
        max_arma_order: int = 1,
        show_warnings: bool = True,
    ):
        self.use_box_cox = bool(use_box_cox)
        self.box_cox_bounds = tuple(float(v) for v in box_cox_bounds)
        self.use_trend = bool(use_trend)
        self.use_damped_trend = bool(use_damped_trend)
        self.damped_trend_phi = float(damped_trend_phi)
        self.seasonal_periods = None if seasonal_periods is None else [int(m) for m in seasonal_periods]
        self.use_arma_errors = bool(use_arma_errors)
        self.arma_order = None if arma_order is None else (int(arma_order[0]), int(arma_order[1]))
        self.max_arma_order = int(max_arma_order)
        self.show_warnings = bool(show_warnings)

    def _fit_sarimax(self, y_x: np.ndarray, X: np.ndarray, order: tuple[int, int]) -> tuple[object, float]:
        p, q = order
        res = sm.tsa.SARIMAX(
            y_x,
            exog=X,
            order=(p, 0, q),
            trend="n",
            enforce_stationarity=True,
            enforce_invertibility=True,
        ).fit(disp=False, method="lbfgs", maxiter=300)
        return res, float(res.aic)

    def _select_arma_order(self, y_x: np.ndarray, X: np.ndarray) -> tuple[int, int]:
        candidates = []
        for p in range(self.max_arma_order + 1):
            for q in range(self.max_arma_order + 1):
                candidates.append((p, q))

        best_order = (0, 0)
        best_aic = np.inf

        for order in candidates:
            try:
                _, aic = self._fit_sarimax(y_x, X, order)
            except Exception:
                continue
            if aic < best_aic:
                best_aic = aic
                best_order = order

        if best_aic == np.inf:
            raise RuntimeError("Failed to fit any ARMA(p,q) candidate.")
        return best_order

    def fit(self, y) -> BATSModel:
        if isinstance(y, pd.Series):
            y_index = y.index
            y_np = y.to_numpy(dtype=float)
        else:
            y_index = None
            y_np = np.asarray(y, dtype=float)

        t = np.arange(y_np.size)
        X = bats_design_matrix(
            t,
            use_trend=self.use_trend,
            use_damped_trend=self.use_damped_trend,
            damped_trend_phi=self.damped_trend_phi,
            seasonal_periods=self.seasonal_periods,
        )

        transformer = BoxCoxTransformer(self.use_box_cox, box_cox_bounds=self.box_cox_bounds).fit(y_np)
        y_x = transformer.transform(y_np)

        if not self.use_arma_errors:
            chosen_order = (0, 0)
        elif self.arma_order is not None:
            chosen_order = self.arma_order
        else:
            chosen_order = self._select_arma_order(y_x, X)

        res, aic = self._fit_sarimax(y_x, X, chosen_order)
        if self.show_warnings:
            print(f"Chosen ARMA(p,q) = {chosen_order}, AIC = {aic:.2f}")

        return BATSModel(
            results=res,
            transformer=transformer,
            use_trend=self.use_trend,
            use_damped_trend=self.use_damped_trend,
            damped_trend_phi=self.damped_trend_phi,
            seasonal_periods=self.seasonal_periods,
            y_index=y_index,
        )

Demo: synthetic series with two seasonalities#

We’ll simulate a daily series with:

  • weekly seasonality (\(m_1=7\))

  • ~monthly seasonality (\(m_2=30\))

  • a small trend

  • ARMA-like correlated noise

def simulate_arma11(n: int, *, phi: float, theta: float, sigma: float, rng: np.random.Generator) -> np.ndarray:
    """Simulate ``n`` points of an ARMA(1, 1) process with Gaussian shocks.

    Recursion: ``u[t] = phi * u[t-1] + eps[t] + theta * eps[t-1]`` with
    ``eps ~ N(0, sigma**2)`` drawn from ``rng``. The values before the start
    are taken as zero, so ``u[0]`` equals ``eps[0]``.
    """
    shocks = rng.normal(0.0, sigma, size=n)
    series = np.zeros(n)
    if n > 0:
        series[0] = shocks[0]
    for i in range(1, n):
        series[i] = phi * series[i - 1] + shocks[i] + theta * shocks[i - 1]
    return series


# Synthetic daily series: level 30 + linear trend + weekly and 30-day cycles
# + ARMA(1, 1) noise.
n = 420
t = np.arange(n)
idx = pd.date_range("2020-01-01", periods=n, freq="D")

# Deterministic components.
trend = 0.01 * t
weekly = 2.0 * np.sin(2 * np.pi * t / 7) + 0.5 * np.cos(2 * np.pi * t / 7)
monthly = 1.2 * np.sin(2 * np.pi * t / 30) - 0.3 * np.cos(2 * np.pi * t / 30)

# Stochastic component, drawn from the notebook-level seeded generator.
noise = simulate_arma11(n, phi=0.6, theta=0.4, sigma=0.8, rng=rng)

y = pd.Series(30.0 + trend + weekly + monthly + noise, index=idx, name="y")

# Quick look at the simulated series.
fig = go.Figure(data=[go.Scatter(x=y.index, y=y, name="y", line={"color": "black"})])
fig.update_layout(
    title="Synthetic multi-seasonal series",
    xaxis_title="date",
    yaxis_title="value",
)
fig.show()
# Hold out the last h observations, fit on the rest, forecast the hold-out.
h = 60
y_train, y_test = y.iloc[:-h], y.iloc[-h:]

bats = BATS(
    use_box_cox=False,
    box_cox_bounds=(0.0, 1.0),
    use_trend=True,
    use_damped_trend=False,
    seasonal_periods=[7, 30],
    use_arma_errors=True,
    arma_order=(1, 1),
    show_warnings=True,
)

model = bats.fit(y_train)
fcst = model.forecast(h)

# Re-attach date indices so everything plots on a common time axis.
fitted = pd.Series(model.fitted_values(), index=y_train.index)
pred_mean, pred_lower, pred_upper = (
    pd.Series(fcst[key], index=y_test.index) for key in ("mean", "lower", "upper")
)

fig = go.Figure()
fig.add_traces(
    [
        go.Scatter(x=y_train.index, y=y_train, name="train", line={"color": "rgba(0,0,0,0.35)"}),
        go.Scatter(x=y_train.index, y=fitted, name="fitted", line={"color": "#4E79A7"}),
        go.Scatter(x=y_test.index, y=y_test, name="test", line={"color": "black"}),
        # Interval band: the invisible upper edge must come immediately before
        # the lower edge, because "tonexty" fills towards the previous trace.
        go.Scatter(x=y_test.index, y=pred_upper, line={"width": 0}, showlegend=False),
        go.Scatter(
            x=y_test.index,
            y=pred_lower,
            fill="tonexty",
            fillcolor="rgba(78,121,167,0.18)",
            line={"width": 0},
            name="95% interval (approx)",
        ),
        go.Scatter(x=y_test.index, y=pred_mean, name="forecast mean", line={"color": "#E15759"}),
    ]
)

fig.update_layout(
    title="BATS forecast on multi-seasonal series",
    xaxis_title="date",
    yaxis_title="value",
)
fig.show()
Chosen ARMA(p,q) = (1, 1), AIC = 841.30
# Residual diagnostics (in transformed space); the first `warmup` residuals
# are left out of every statistic and plot below.
resid = model.residuals()
warmup = 10
resid_use = resid[warmup:]

resid_mean = resid_use.mean()
resid_std = resid_use.std(ddof=1)

print("residual mean:", float(resid_mean))
print("residual std:", float(resid_std))
print("Jarque-Bera:", stats.jarque_bera(resid_use))

# Autocorrelation with the usual +/- 1.96 / sqrt(n) reference band.
lags, acf_vals = _acf(resid_use, max_lag=30)
bound = 1.96 / np.sqrt(resid_use.size)

# QQ data: standardised, sorted residuals against normal quantiles.
nq = resid_use.size
p = (np.arange(1, nq + 1) - 0.5) / nq
theoretical = stats.norm.ppf(p)
sample_q = np.sort((resid_use - resid_mean) / resid_std)

fig = make_subplots(
    rows=2,
    cols=2,
    subplot_titles=("Residuals over time", "Residual histogram", "Residual ACF", "QQ plot (std residuals)"),
)

# Top-left: residuals over time. The zero line is added after the trace so the
# subplot is no longer empty when add_hline runs.
fig.add_trace(
    go.Scatter(x=y_train.index[warmup:], y=resid_use, name="residuals", line={"color": "#4E79A7"}),
    row=1,
    col=1,
)
fig.add_hline(y=0, line={"color": "black", "dash": "dash"}, row=1, col=1)

# Top-right: histogram.
fig.add_trace(go.Histogram(x=resid_use, nbinsx=30, name="hist", marker_color="#4E79A7"), row=1, col=2)

# Bottom-left: ACF bars with the upper and lower reference lines.
fig.add_trace(go.Bar(x=lags, y=acf_vals, name="ACF(resid)", marker_color="#4E79A7"), row=2, col=1)
for _level in (bound, -bound):
    fig.add_trace(
        go.Scatter(
            x=[0, lags.max()],
            y=[_level, _level],
            mode="lines",
            line={"color": "gray", "dash": "dash"},
            showlegend=False,
        ),
        row=2,
        col=1,
    )

# Bottom-right: QQ points with the y = x reference line.
fig.add_trace(
    go.Scatter(x=theoretical, y=sample_q, mode="markers", name="QQ", marker={"color": "#4E79A7"}),
    row=2,
    col=2,
)
_qq_span = [theoretical.min(), theoretical.max()]
fig.add_trace(
    go.Scatter(x=_qq_span, y=_qq_span, mode="lines", line={"color": "black", "dash": "dash"}, showlegend=False),
    row=2,
    col=2,
)

fig.update_layout(height=750, title="BATS residual diagnostics")
fig.show()
residual mean: -0.003120780574746758
residual std: 0.7024190892644842
Jarque-Bera: SignificanceResult(statistic=0.20620374582574355, pvalue=0.9020350758599663)